{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import torch\n",
    "import numpy as np\n",
    "import PIL\n",
    "from PIL import Image\n",
    "from IPython.display import HTML\n",
    "from pyramid_dit import PyramidDiTForVideoGeneration\n",
    "from IPython.display import Image as ipython_image\n",
    "from diffusers.utils import load_image, export_to_video, export_to_gif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "variant='diffusion_transformer_image'       # For low resolution\n",
    "model_name = \"pyramid_flux\"\n",
    "\n",
    "model_path = \"/home/jinyang06/models/pyramid-flow-miniflux\"   # The downloaded checkpoint dir\n",
    "model_dtype = 'bf16'\n",
    "\n",
    "device_id = 0\n",
    "torch.cuda.set_device(device_id)\n",
    "\n",
    "model = PyramidDiTForVideoGeneration(\n",
    "    model_path,\n",
    "    model_dtype,\n",
    "    model_name=model_name,\n",
    "    model_variant=variant,\n",
    ")\n",
    "\n",
    "model.vae.to(\"cuda\")\n",
    "model.dit.to(\"cuda\")\n",
    "model.text_encoder.to(\"cuda\")\n",
    "\n",
    "model.vae.enable_tiling()\n",
    "\n",
    "if model_dtype == \"bf16\":\n",
    "    torch_dtype = torch.bfloat16 \n",
    "elif model_dtype == \"fp16\":\n",
    "    torch_dtype = torch.float16\n",
    "else:\n",
    "    torch_dtype = torch.float32"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Text-to-Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"shoulder and full head portrait of a beautiful 19 year old girl, brunette, smiling, stunning, highly detailed, glamour lighting, HDR, photorealistic, hyperrealism, octane render, unreal engine\"\n",
    "\n",
    "# now support 3 aspect ratios\n",
    "resolution_dict = {\n",
    "    '1:1' : (1024, 1024),\n",
    "    '5:3' : (1280, 768),\n",
    "    '3:5' : (768, 1280),\n",
    "}\n",
    "\n",
    "ratio = '1:1'   # 1:1, 5:3, 3:5\n",
    "\n",
    "width, height = resolution_dict[ratio]\n",
    "\n",
    "\n",
    "with torch.no_grad(), torch.cuda.amp.autocast(enabled=True if model_dtype != 'fp32' else False, dtype=torch_dtype):\n",
    "    images = model.generate(\n",
    "        prompt=prompt,\n",
    "        num_inference_steps=[20, 20, 20],\n",
    "        height=height,\n",
    "        width=width,\n",
    "        temp=1,\n",
    "        guidance_scale=9.0,        \n",
    "        output_type=\"pil\",\n",
    "        save_memory=False, \n",
    "    )\n",
    "\n",
    "display(images[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
